# Flag controlling whether the expensive MCMC chunks are evaluated.
# Spelled out as FALSE because `F` is an ordinary variable that can be reassigned.
eval_mcmc <- FALSE

Simulation Validation

This notebook simulates several different sets of genealogies containing a single expansion and tests MCMC inference against them, aiming to validate how well parameters are recovered across different scales.

Parameter Recovery

This section validates how well model parameters are recovered across different population size scales and different expansion rates.

Set up the experiment

# Seed R's RNG so anything random in this notebook is reproducible.
set.seed(1)

# MCMC settings: `burn_in` is the proportion of samples discarded before
# summarising; `thinning` is derived so that n_it / thinning == 1e4.
burn_in <- 0.1
n_it <- 1e7
thinning <- n_it / 1e4

# Experiment grid: number of repetitions per setting, number of expansions
# passed to the simulator ("-e"), expansion rates ("-l") and population
# scales ("--meanscale").
repetitions <- 200
dims <- 1
lambdar_range <- c(5)
pop_range <- c(5)

# Create the spec/output roots. `showWarnings = FALSE` keeps re-runs quiet when
# the directories already exist; `recursive = TRUE` also creates `base_dir`
# itself if it is missing.
dir.create(file.path(base_dir, "test_spec"), showWarnings = FALSE, recursive = TRUE)
dir.create(file.path(base_dir, "test_out"), showWarnings = FALSE, recursive = TRUE)

# Simulator seed; incremented once per generated spec file.
seed <- 0

# Write one simulator spec file per (expansion rate i, population scale j,
# repetition l). Each file holds exactly 11 option lines, one option per line.
for (i in seq_along(lambdar_range)) {
  dir_name <- paste0("lambdar_", i)
  # Re-run safe: do not warn when the per-rate directories already exist.
  dir.create(file.path(paste0(base_dir, "/test_spec"), dir_name),
             showWarnings = FALSE, recursive = TRUE)
  dir.create(file.path(paste0(base_dir, "/test_out"), dir_name),
             showWarnings = FALSE, recursive = TRUE)
  for (j in seq_along(pop_range)) {
    for (l in seq_len(repetitions)) {
      params <- c(paste0("-e ", dims),
                  paste0("-s ", seed),
                  paste0("-t ", 100),
                  paste0("-l ", lambdar_range[i]),
                  # Population scale is indexed by j (the population loop);
                  # the original indexed it with i, the expansion-rate index.
                  paste0("--meanscale ", pop_range[j]),
                  paste0("--sdscale ", 1/2),
                  paste0("--kappa ", 1/4),
                  paste0("--nu ", 1/3),
                  paste0("--sdk ", 0.5),
                  paste0("-o ", base_dir, "/test_out/", dir_name, "/pop_", j, "_r_", l),
                  paste0("--metadata ", paste0("\"", "lambdar_idx:", i, ",N_idx:", j, "\"")))
      fname <- paste0("pop_", j, "_r_", l, ".txt")
      # writeLines() opens and closes the file itself when given a path, so no
      # connection can be left open if writing fails.
      writeLines(params, paste0(base_dir, "/test_spec/", dir_name, "/", fname))
      # Every spec gets its own simulator seed.
      seed <- seed + 1
    }
  }
}

Simulate the parameters of individual processes

# Concatenate the spec files and hand every 11 lines (one spec) to one simulate_tree.R run.
# The grep pattern is quoted so the shell cannot glob-expand it before grep sees it.
find ./Paper/Sim/sim_one_exp | grep '.*.txt' | xargs -L 1 cat | parallel -j 10 --verbose --tmpdir ./Paper/Sim/tmp -n 11 --halt-on-error 2 eval Rscript "./Scripts/simulate_tree.R {}"

Run MCMC inference

# Run the inference script once per directory that contains a simulated tree.nwk.
# The grep pattern is quoted so the shell cannot glob-expand it before grep sees it.
find ./Paper/Sim/sim_one_exp | grep '.*tree.nwk' | xargs dirname | parallel -j 10 --verbose --tmpdir ./Paper/Sim/tmp --halt-on-error 2 eval Rscript "./Scripts/run_expansions_inference_nwk.R -n 1e7 -t 1e3 --lambdar 5 -s 1 -f {}/tree.nwk -o {}"

Produce plots for dimension mode analysis

#' Equal-tailed interval of a sample
#'
#' Sorts `x`, splits it into a lower and an upper half (the middle element is
#' left out when the length is odd) and picks the element at position
#' `trunc(half * (1 - conf))` of the lower half and `trunc(half * conf)` of the
#' upper half.
#'
#' Positions are clamped to `[1, half]`. Without the clamp, small samples
#' (where `half * (1 - conf) < 1`) index position 0 and the function fails with
#' "replacement has length zero". For samples large enough that no clamping is
#' needed the result is identical to the unclamped computation.
#'
#' @param x Vector of sampled values.
#' @param conf Coverage of the interval, default 0.95.
#' @return Vector of length 2: lower bound, upper bound. `c(NA, NA)` for empty
#'   input; the single value repeated twice for a length-one input.
compute_ci <- function(x, conf = 0.95) {
  n <- length(x)
  if (n == 0) {
    return(c(NA_real_, NA_real_))
  }

  sorted <- x[order(x)]
  half <- n %/% 2
  if (half == 0) {
    # A single observation: the interval degenerates to that value.
    return(rep(unname(sorted[1]), 2))
  }

  lower_half <- sorted[seq_len(half)]
  # For odd n this skips the middle element, as n - half + 1 == half + 2.
  upper_half <- sorted[(n - half + 1):n]

  # trunc() reproduces how R truncates a fractional subscript.
  lo_idx <- min(max(trunc(half * (1 - conf)), 1), half)
  hi_idx <- min(max(trunc(half * conf), 1), half)

  unname(c(lower_half[lo_idx], upper_half[hi_idx]))
}

# Locate every simulated experiment: one tree_params.json per run, searched
# recursively below <base_dir>/test_out. The containing directory of each file
# is kept alongside it, index for index.
experiment_specs <- list.files(
  path = paste0(base_dir, "/test_out"),
  pattern = ".*tree_params.json",
  recursive = TRUE,
  full.names = TRUE
)
experiment_dirs <- dirname(experiment_specs)

# Summarise every experiment against its ground truth.
#
# For each tree_params.json this reads the simulated parameters and the
# matching expansions.rds, discards burn-in, and returns a flat list with:
#   * how often the sampled number of expansions matches the truth,
#   * medians and intervals (compute_ci) of N, K, t_mid and expansion time,
#   * how often the sampled expansion branch matches the truth, and
#   * the Jaccard distance between the tips below the modal branch and the
#     tips below the true branch.
# When no retained sample has the true number of expansions, all
# branch-conditional summaries are NA and the branch probabilities are 0.
results_data <- lapply(seq_along(experiment_specs), function(i) {
  # Ground truth: flatten every JSON entry to a numeric vector.
  sim_data <- fromJSON(file = experiment_specs[i])
  sim_data <- lapply(sim_data, unlist)
  sim_data <- lapply(sim_data, as.numeric)

  dim.gt <- sim_data$n_exp
  t_mid.gt <- sim_data$t_mid
  K.gt <- sim_data$K
  N.gt <- sim_data$N
  br.gt <- sim_data$root_set
  # The last divergence time is dropped.
  # NOTE(review): presumably it is not an expansion event - confirm against
  # simulate_tree.R.
  time.gt <- sim_data$div_times
  time.gt <- time.gt[-length(time.gt)]

  ### Check that we got the number of expansions approximately right
  expansions <- readRDS(paste0(experiment_dirs[i], "/expansions.rds"))
  expansions <- discard_burn_in(expansions, proportion = burn_in)
  pre <- expansions$phylo_preprocessed

  mcmc_data <- expansions$model_data
  event_data <- expansions$expansion_data

  p_correct_dim <- length(which(mcmc_data$dim == dim.gt)) / length(mcmc_data$dim)
  expected_dim <- sum(mcmc_data$dim) / length(mcmc_data$dim)

  # Most frequently sampled number of expansions (first seen wins ties).
  unique.dim <- unique(mcmc_data$dim)
  dim_counts <- vapply(unique.dim,
                       function(d) length(which(mcmc_data$dim == d)),
                       integer(1))
  mode_dim <- unique.dim[which.max(dim_counts)]

  ### Take the marginal with the correct number of expansions
  correct_dim <- mcmc_data[which(mcmc_data$dim == dim.gt), ]
  correct_dim_it <- correct_dim$it
  event_rows <- unlist(lapply(correct_dim_it,
                              function(it) which(event_data$it == it)))
  event_dim_marginal <- event_data[event_rows, ]

  # Defaults for the "no usable sample" case. Every name used in the returned
  # list must exist on both branches, otherwise the list construction below
  # fails with "object not found".
  expected_N <- NA_real_
  expected_K <- NA_real_
  expected_t_mid <- NA_real_
  expected_T <- NA_real_
  ci_N <- c(NA_real_, NA_real_)
  ci_K <- c(NA_real_, NA_real_)
  ci_t_mid <- c(NA_real_, NA_real_)
  ci_T <- c(NA_real_, NA_real_)
  jacc_dist <- NA_real_
  mode_branch <- NA_real_
  is_mode_correct <- NA
  p_correct_br <- 0
  p_mode_br <- 0

  if (nrow(event_dim_marginal) > 0) {
    ### Background population size over samples with the correct dimension
    expected_N <- median(correct_dim$N)
    ci_N <- compute_ci(correct_dim$N)

    ### Branch recovery: probability of the true branch and the modal branch
    p_correct_br <- length(which(event_dim_marginal$br == br.gt)) /
      length(event_dim_marginal$br)
    unique.br <- unique(event_dim_marginal$br)
    br_counts <- vapply(unique.br,
                        function(b) length(which(event_dim_marginal$br == b)),
                        integer(1))
    mode_branch <- unique.br[which.max(br_counts)]
    is_mode_correct <- mode_branch == br.gt

    ### Take the modal-branch marginal
    mode_subs <- which(event_dim_marginal$br == mode_branch)
    p_mode_br <- length(mode_subs) / length(event_dim_marginal$br)

    event_br_marginal <- event_dim_marginal[mode_subs, ]
    expected_t_mid <- median(event_br_marginal$t_mid)
    ci_t_mid <- compute_ci(event_br_marginal$t_mid)
    expected_K <- median(event_br_marginal$K)
    ci_K <- compute_ci(event_br_marginal$K)
    # Sampled times are negated before summarising, matching `time_gt` below.
    expected_T <- median(-event_br_marginal$time)
    ci_T <- compute_ci(-event_br_marginal$time)

    ### Jaccard distance between tips under the true and the modal branch
    mrca.gt <- pre$edges.df$node.child[br.gt]
    mrca.mode <- pre$edges.df$node.child[mode_branch]

    gt.tips <- pre$clades.list[[mrca.gt - pre$n_tips]]$tip.label
    mode.tips <- pre$clades.list[[mrca.mode - pre$n_tips]]$tip.label

    intersection <- sum(vapply(gt.tips,
                               function(tip) length(which(mode.tips == tip)),
                               integer(1),
                               USE.NAMES = FALSE))

    jacc_dist <- 1 - intersection /
      (length(gt.tips) + length(mode.tips) - intersection)
  }

  # `N_gt` is listed once; a repeated name would give the assembled data frame
  # two columns with the same name.
  list(p_correct_dim = p_correct_dim,
       expected_dim = expected_dim,
       mode_dim = mode_dim,
       mode_branch = mode_branch,
       expected_N = expected_N,
       ci_N_lo = ci_N[1],
       ci_N_hi = ci_N[2],
       expected_T = expected_T,
       ci_T_lo = ci_T[1],
       ci_T_hi = ci_T[2],
       expected_K = expected_K,
       ci_K_lo = ci_K[1],
       ci_K_hi = ci_K[2],
       expected_t_mid = expected_t_mid,
       ci_t_mid_lo = ci_t_mid[1],
       ci_t_mid_hi = ci_t_mid[2],
       p_correct_br = p_correct_br,
       p_mode_br = p_mode_br,
       is_mode_correct = is_mode_correct,
       jacc_dist = jacc_dist,
       t_mid_gt = t_mid.gt,
       dim_gt = dim.gt,
       K_gt = K.gt,
       N_gt = N.gt,
       time_gt = -time.gt,
       br_gt = br.gt)
})
# Stack the per-experiment summaries into one data frame, one row per experiment.
# seq_along() is used instead of 1:length() so an empty result list does not
# produce the names c(1, 0).
names(results_data) <- seq_along(results_data)
data_df <- do.call(rbind.data.frame, results_data)

# Experiments whose modal number of expansions is 1.
# NOTE(review): 1 is hard-coded; it equals `dims` in this notebook - confirm
# before changing `dims`.
data_df_dim <- data_df[which(data_df$mode_dim == 1), ]
# Of those, experiments whose modal branch is the true branch; which() also
# drops rows where the flag is NA.
data_df_m <- data_df_dim[which(data_df_dim$is_mode_correct), ]
head(data_df)
# Bar chart of the modal number of expansions across all experiments, with the
# proportion of each bar printed above it as a percentage.
dim_hist <- ggplot(data_df, aes(x = mode_dim)) +
  geom_bar(aes(y = ..prop..), stat = "count") +
  geom_text(
    aes(label = scales::percent(..prop..), y = ..prop..),
    stat = "count",
    vjust = -.5,
    size = 15
  ) +
  theme_bw() +
  labs(x = "Mode No. of Expansions") +
  scale_y_continuous(labels = percent, limits = c(0, 1)) +
  theme(
    axis.text.x = element_text(angle = 45, hjust = 1),
    axis.title.y = element_blank(),
    text = element_text(size = 35)
  )

# Histogram (as a share of experiments) of the probability assigned to the true
# branch, restricted to experiments in `data_df_dim`. The y axis carries no
# title, labels or ticks.
p_hist <- ggplot(data_df_dim, aes(x = p_correct_br)) +
  geom_histogram(aes(y = stat(count) / sum(count)), bins = 40) +
  scale_y_continuous(labels = percent, limits = c(0, 1)) +
  theme_bw() +
  labs(x = "Prob. Correct Branch") +
  theme(
    axis.text.x = element_text(angle = 45, hjust = 1),
    axis.title.y = element_blank(),
    axis.text.y = element_blank(),
    axis.ticks.y = element_blank(),
    text = element_text(size = 35)
  )

# Histogram (as a share of experiments) of the Jaccard distance between the
# modal and the true expansion clade, restricted to experiments in
# `data_df_dim`. The y axis carries no title, labels or ticks.
jacc_hist <- ggplot(data_df_dim, aes(x = jacc_dist)) +
  geom_histogram(aes(y = stat(count) / sum(count)), bins = 40) +
  scale_y_continuous(labels = percent, limits = c(0, 1)) +
  theme_bw() +
  labs(x = "Jaccard Distance") +
  theme(
    axis.text.x = element_text(angle = 45, hjust = 1),
    axis.title.y = element_blank(),
    axis.text.y = element_blank(),
    axis.ticks.y = element_blank(),
    text = element_text(size = 35)
  )

# Combine the three histograms into one figure and write it to fig3a.png.
p <- ggarrange(dim_hist, p_hist, jacc_hist, widths = c(2, 2, 2), heights = c(1))

png("./Paper/Figures/fig3a.png", width = 1600, height = 800)
# Print explicitly: a bare `p` is only drawn by top-level auto-printing, so the
# PNG would be empty if this code were source()d or moved into a function/loop.
print(p)
dev.off()
## png 
##   2
# True vs. median background population size (log10 axes, equal aspect ratio)
# with interval error bars and the identity line. Both axes share the limits
# spanned by the interval bounds.
lims <- c(min(data_df_m$ci_N_lo), max(data_df_m$ci_N_hi))
N_scatter <- ggplot(data_df_m, aes(x = N_gt, y = expected_N)) +
  geom_point() +
  geom_errorbar(aes(ymin = ci_N_lo, ymax = ci_N_hi), width = 0.2) +
  scale_x_continuous(trans = "log10", limits = lims) +
  scale_y_continuous(trans = "log10", limits = lims) +
  theme_bw() +
  labs(x = "True Background Population", y = "Median Background Population") +
  coord_fixed(ratio = 1) +
  theme(
    axis.text.x = element_text(angle = 45, hjust = 1),
    text = element_text(size = 35)
  ) +
  geom_abline(intercept = 0, slope = 1)

# True vs. median carrying capacity (log10 axes, equal aspect ratio) with
# interval error bars and the identity line. Both axes share the limits spanned
# by the interval bounds.
lims <- c(min(data_df_m$ci_K_lo), max(data_df_m$ci_K_hi))
K_scatter <- ggplot(data_df_m, aes(x = K_gt, y = expected_K)) +
  geom_point() +
  geom_errorbar(aes(ymin = ci_K_lo, ymax = ci_K_hi), width = 0.2) +
  scale_x_continuous(trans = "log10", limits = lims) +
  scale_y_continuous(trans = "log10", limits = lims) +
  theme_bw() +
  labs(x = "True Carrying Capacity", y = "Median Carrying Capacity") +
  coord_fixed(ratio = 1) +
  theme(
    axis.text.x = element_text(angle = 45, hjust = 1),
    text = element_text(size = 35)
  ) +
  geom_abline(intercept = 0, slope = 1)

# True vs. median expansion time (log10 axes, equal aspect ratio) with interval
# error bars and the identity line. Both axes share the limits spanned by the
# interval bounds.
lims <- c(min(data_df_m$ci_T_lo), max(data_df_m$ci_T_hi))
T_scatter <- ggplot(data_df_m, aes(x = time_gt, y = expected_T)) +
  geom_point() +
  geom_errorbar(aes(ymin = ci_T_lo, ymax = ci_T_hi), width = 0.2) +
  scale_x_continuous(trans = "log10", limits = lims) +
  scale_y_continuous(trans = "log10", limits = lims) +
  theme_bw() +
  labs(x = "True Time of Expansions", y = "Median Time of Expansions") +
  coord_fixed(ratio = 1) +
  theme(
    axis.text.x = element_text(angle = 45, hjust = 1),
    text = element_text(size = 35)
  ) +
  geom_abline(intercept = 0, slope = 1)

# True vs. median time to midpoint (log10 axes, equal aspect ratio) with
# interval error bars and the identity line. Both axes share the limits spanned
# by the interval bounds.
lims <- c(min(data_df_m$ci_t_mid_lo), max(data_df_m$ci_t_mid_hi))
t_mid_scatter <- ggplot(data_df_m, aes(x = t_mid_gt, y = expected_t_mid)) +
  geom_point() +
  geom_errorbar(aes(ymin = ci_t_mid_lo, ymax = ci_t_mid_hi), width = 0.2) +
  scale_x_continuous(trans = "log10", limits = lims) +
  scale_y_continuous(trans = "log10", limits = lims) +
  theme_bw() +
  labs(x = "True Time to Midpoint", y = "Median Time to Midpoint") +
  coord_fixed(ratio = 1) +
  theme(
    axis.text.x = element_text(angle = 45, hjust = 1),
    text = element_text(size = 35)
  ) +
  geom_abline(intercept = 0, slope = 1)

# Combine the four parameter-recovery scatter plots and write them to fig3b.png.
p <- ggarrange(N_scatter, K_scatter, t_mid_scatter, T_scatter, widths = c(2, 2), heights = c(2, 2))

png("./Paper/Figures/fig3b.png", width = 1600, height = 1600)
# Print explicitly: a bare `p` is only drawn by top-level auto-printing, so the
# PNG would be empty if this code were source()d or moved into a function/loop.
print(p)
dev.off()
## png 
##   2

Done